{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d999c38-2087-4250-b6a4-a30cf8b44ec0",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T13:11:38.997916Z",
     "iopub.status.busy": "2023-07-05T13:11:38.997587Z",
     "iopub.status.idle": "2023-07-05T13:11:39.001928Z",
     "shell.execute_reply": "2023-07-05T13:11:39.001429Z",
     "shell.execute_reply.started": "2023-07-05T13:11:38.997898Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import os.path as osp\n",
    "import torch\n",
    "import numpy as np\n",
    "import mmcv\n",
    "import cv2\n",
    "from mmengine.utils import track_iter_progress"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfa9bf9b-dc2c-4803-a034-8ae8778113e0",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:42:15.884465Z",
     "iopub.status.busy": "2023-07-05T12:42:15.884167Z",
     "iopub.status.idle": "2023-07-05T12:42:19.774569Z",
     "shell.execute_reply": "2023-07-05T12:42:19.774020Z",
     "shell.execute_reply.started": "2023-07-05T12:42:15.884448Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fetch the two demo clips into ./resources (needs `wget` and network access).\n",
    "from mmengine.utils import mkdir_or_exist\n",
    "\n",
    "# local paths used by every later cell\n",
    "student_video = 'resources/student_video.mp4'\n",
    "teacher_video = 'resources/teacher_video.mp4'\n",
    "\n",
    "mkdir_or_exist('resources')\n",
    "! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tom.mp4 \n",
    "! wget -O resources/teacher_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/idol_producer.mp4 \n",
    "# alternative student clip (uncomment to use it instead):\n",
    "# ! wget -O resources/student_video.mp4 https://download.openmmlab.com/mmpose/v1/projects/just_dance/tsinghua_30fps.mp4 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "652b6b91-e1c0-461b-90e5-653bc35ec380",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:42:20.693931Z",
     "iopub.status.busy": "2023-07-05T12:42:20.693353Z",
     "iopub.status.idle": "2023-07-05T12:43:14.533985Z",
     "shell.execute_reply": "2023-07-05T12:43:14.533431Z",
     "shell.execute_reply.started": "2023-07-05T12:42:20.693910Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# convert the fps of videos to 30\n",
    "from mmcv import VideoReader\n",
    "\n",
    "# Re-encode a clip only when its frame rate differs from 30.\n",
    "# NOTE: the check must read `VideoReader.fps`; comparing the reader object\n",
    "# itself with 30 is always unequal and would re-encode on every run.\n",
    "if VideoReader(student_video).fps != 30:\n",
    "    # ffmpeg is required to convert the video fps\n",
    "    # which can be installed via `sudo apt install ffmpeg` on ubuntu\n",
    "    # 'dir/name.ext' -> 'dir/name_30fps.ext'\n",
    "    video_stem, video_ext = osp.splitext(student_video)\n",
    "    student_video_30fps = f'{video_stem}_30fps{video_ext}'\n",
    "    !ffmpeg -i {student_video} -vf \"minterpolate='fps=30'\" {student_video_30fps}\n",
    "    student_video = student_video_30fps\n",
    "\n",
    "if VideoReader(teacher_video).fps != 30:\n",
    "    video_stem, video_ext = osp.splitext(teacher_video)\n",
    "    teacher_video_30fps = f'{video_stem}_30fps{video_ext}'\n",
    "    !ffmpeg -i {teacher_video} -vf \"minterpolate='fps=30'\" {teacher_video_30fps}\n",
    "    teacher_video = teacher_video_30fps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4e141d-ee4a-4e06-a380-230418c9b936",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:45:01.672054Z",
     "iopub.status.busy": "2023-07-05T12:45:01.671727Z",
     "iopub.status.idle": "2023-07-05T12:45:02.417026Z",
     "shell.execute_reply": "2023-07-05T12:45:02.416567Z",
     "shell.execute_reply.started": "2023-07-05T12:45:01.672035Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build the pose inferencer together with the person detector it relies on.\n",
    "from mmpose.apis.inferencers import Pose2DInferencer\n",
    "\n",
    "det_weights_url = ('https://download.openmmlab.com/mmpose/v1/projects/'\n",
    "                   'rtmpose/rtmdet_nano_8xb32-100e_coco-obj365-person-05d8511e.pth')\n",
    "pose_estimator = Pose2DInferencer(\n",
    "    'rtmpose-t_8xb256-420e_aic-coco-256x192',\n",
    "    det_model='configs/rtmdet-nano_one-person.py',\n",
    "    det_weights=det_weights_url)\n",
    "\n",
    "# switch off the flip-test option in the pose model's test config\n",
    "pose_estimator.model.test_cfg['flip_test'] = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "879ba5c0-4d2d-4cca-92d7-d4f94e04a821",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:45:05.192437Z",
     "iopub.status.busy": "2023-07-05T12:45:05.191982Z",
     "iopub.status.idle": "2023-07-05T12:45:05.197379Z",
     "shell.execute_reply": "2023-07-05T12:45:05.196780Z",
     "shell.execute_reply.started": "2023-07-05T12:45:05.192417Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_keypoints_from_frame(image, pose_estimator):\n",
    "    \"\"\"Detect the person in one frame and estimate their keypoints.\n",
    "\n",
    "    Args:\n",
    "        image: A single video frame; handed to the detector and to the pose\n",
    "            pipeline as ``img``.\n",
    "        pose_estimator: Inferencer exposing ``detector``, ``pipeline``,\n",
    "            ``collate_fn`` and ``model``.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Keypoint coordinates with the keypoint score appended as\n",
    "        the last channel. When nothing is detected, or the first detection\n",
    "        scores below 0.2, an all-zero float32 array of shape (1, 17, 3).\n",
    "    \"\"\"\n",
    "    detections = pose_estimator.detector(image, return_datasample=True)\n",
    "    instances = detections['predictions'][0].pred_instances.numpy()\n",
    "\n",
    "    # no (confident) detection in this frame -> placeholder pose\n",
    "    if len(instances) == 0 or instances.scores[0] < 0.2:\n",
    "        return np.zeros((1, 17, 3), dtype=np.float32)\n",
    "\n",
    "    # only the first detected box is passed on to the pose model\n",
    "    sample = {\n",
    "        'img': image,\n",
    "        'bbox': instances.bboxes[:1],\n",
    "        'bbox_score': instances.scores[:1],\n",
    "    }\n",
    "    sample.update(pose_estimator.model.dataset_meta)\n",
    "    batch = pose_estimator.collate_fn([pose_estimator.pipeline(sample)])\n",
    "\n",
    "    # custom forward: preprocess -> feature extraction -> head prediction\n",
    "    model = pose_estimator.model\n",
    "    batch = model.data_preprocessor(batch, False)\n",
    "    features = model.extract_feat(batch['inputs'])\n",
    "    prediction = model.head.predict(\n",
    "        features, batch['data_samples'], test_cfg=model.test_cfg)[0]\n",
    "\n",
    "    scores = prediction.keypoint_scores[..., None]\n",
    "    return np.concatenate((prediction.keypoints, scores), axis=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e5bd4c-4c2b-4fe0-b64c-1afed67b7688",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:47:55.564788Z",
     "iopub.status.busy": "2023-07-05T12:47:55.564450Z",
     "iopub.status.idle": "2023-07-05T12:49:37.222662Z",
     "shell.execute_reply": "2023-07-05T12:49:37.222028Z",
     "shell.execute_reply.started": "2023-07-05T12:47:55.564770Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# pose estimation in two videos: one keypoint array per frame,\n",
    "# stacked along the first axis (student first, then teacher)\n",
    "student_poses = np.concatenate([\n",
    "    get_keypoints_from_frame(frame, pose_estimator)\n",
    "    for frame in VideoReader(student_video)\n",
    "])\n",
    "teacher_poses = np.concatenate([\n",
    "    get_keypoints_from_frame(frame, pose_estimator)\n",
    "    for frame in VideoReader(teacher_video)\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38a8d7a5-17ed-4ce2-bb8b-d1637cb49578",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:55:09.342432Z",
     "iopub.status.busy": "2023-07-05T12:55:09.342185Z",
     "iopub.status.idle": "2023-07-05T12:55:09.350522Z",
     "shell.execute_reply": "2023-07-05T12:55:09.350099Z",
     "shell.execute_reply.started": "2023-07-05T12:55:09.342416Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# keypoints taken into account when comparing poses: index 0 and 5..16\n",
    "valid_indices = np.array([0, *range(5, 17)])\n",
    "\n",
    "@torch.no_grad()\n",
    "def _calculate_similarity(tch_kpts: np.ndarray, stu_kpts: np.ndarray):\n",
    "    \"\"\"Pairwise pose similarity between student and teacher frames.\n",
    "\n",
    "    Args:\n",
    "        tch_kpts (np.ndarray): Teacher keypoints, indexed as\n",
    "            [frame, keypoint, (x, y, score)].\n",
    "        stu_kpts (np.ndarray): Student keypoints, same layout.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: Similarity of shape\n",
    "        [num_student_frames, num_teacher_frames]; it lives on the GPU when\n",
    "        CUDA is available.\n",
    "    \"\"\"\n",
    "    stu = torch.from_numpy(stu_kpts[:, None, valid_indices])\n",
    "    tch = torch.from_numpy(tch_kpts[None, :, valid_indices])\n",
    "    pair_shape = (stu.shape[0], tch.shape[1], stu.shape[2], 3)\n",
    "\n",
    "    # last axis: 0 = student, 1 = teacher\n",
    "    pairs = torch.stack(\n",
    "        (stu.expand(pair_shape), tch.expand(pair_shape)), dim=4)\n",
    "    if torch.cuda.is_available():\n",
    "        pairs = pairs.cuda()\n",
    "\n",
    "    # only consider keypoints that are visible in both poses\n",
    "    both_visible = (pairs[..., 2, 0] > 0.3) & (pairs[..., 2, 1] > 0.3)\n",
    "    pairs[~both_visible] = 0.0\n",
    "\n",
    "    # per-pose bounding range of x and y; zeroed entries are replaced by 256\n",
    "    # so that they do not win the minimum\n",
    "    xy = pairs[..., :2, :]\n",
    "    lower = xy.masked_fill(xy == 0, 256).min(dim=2, keepdim=True).values\n",
    "    upper = xy.max(dim=2, keepdim=True).values\n",
    "    xy_norm = (xy - lower) / (upper - lower + 1e-4)\n",
    "\n",
    "    visible = (pairs[..., 2, :] > 0.3).float()\n",
    "    xy_dist = xy_norm[..., 0] - xy_norm[..., 1]\n",
    "    score = visible[..., 0] * visible[..., 1]\n",
    "    num_visible_kpts = score.sum(dim=-1)\n",
    "\n",
    "    # visibility-weighted mean of exp(-50 * squared distance)\n",
    "    similarity = (torch.exp(-50 * xy_dist.pow(2).sum(dim=-1)) *\n",
    "                  score).sum(dim=-1) / (num_visible_kpts + 1e-6)\n",
    "    # scale by a log factor that grows with the number of visible keypoints\n",
    "    visibility_weight = torch.log(\n",
    "        (1 + (num_visible_kpts - 1) * 10).clamp(min=1))\n",
    "    similarity = similarity * visibility_weight / np.log(161)\n",
    "\n",
    "    similarity[similarity.isnan()] = 0\n",
    "\n",
    "    return similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "658bcf89-df06-4c73-9323-8973a49c14c3",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-05T12:55:31.978675Z",
     "iopub.status.busy": "2023-07-05T12:55:31.978219Z",
     "iopub.status.idle": "2023-07-05T12:55:32.149624Z",
     "shell.execute_reply": "2023-07-05T12:55:32.148568Z",
     "shell.execute_reply.started": "2023-07-05T12:55:31.978657Z"
    }
   },
   "outputs": [],
   "source": [
    "# similarity of every (student frame, teacher frame) pair, student as-is\n",
    "similarity1 = _calculate_similarity(teacher_poses, student_poses)\n",
    "\n",
    "# same again with the student mirrored: swap the paired keypoint indices\n",
    "# and reflect the x coordinate (fancy indexing yields a copy, so\n",
    "# `student_poses` itself is not modified)\n",
    "flip_indices = np.array(\n",
    "    [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15])\n",
    "student_poses_flip = student_poses[:, flip_indices]\n",
    "student_poses_flip[..., 0] = 191.5 - student_poses_flip[..., 0]\n",
    "similarity2 = _calculate_similarity(teacher_poses, student_poses_flip)\n",
    "\n",
    "# keep, per frame pair, the larger of the two similarities\n",
    "similarity = torch.maximum(similarity1, similarity2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f981410d-4585-47c1-98c0-6946f948487d",
   "metadata": {
    "ExecutionIndicator": {
     "show": false
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:55:57.321845Z",
     "iopub.status.busy": "2023-07-05T12:55:57.321530Z",
     "iopub.status.idle": "2023-07-05T12:55:57.582879Z",
     "shell.execute_reply": "2023-07-05T12:55:57.582425Z",
     "shell.execute_reply.started": "2023-07-05T12:55:57.321826Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# show the similarity matrix as an image\n",
    "# (rows: student frames, columns: teacher frames)\n",
    "similarity_map = similarity.cpu().numpy()\n",
    "plt.imshow(similarity_map)\n",
    "\n",
    "# there is an apparent diagonal in the figure\n",
    "# we can select matched video snippets with this diagonal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13c189e5-fc53-46a2-9057-f0f2ffc1f46d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-05T12:58:16.913855Z",
     "iopub.status.busy": "2023-07-05T12:58:16.913529Z",
     "iopub.status.idle": "2023-07-05T12:58:16.919972Z",
     "shell.execute_reply": "2023-07-05T12:58:16.919005Z",
     "shell.execute_reply.started": "2023-07-05T12:58:16.913837Z"
    }
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def select_piece_from_similarity(similarity):\n",
    "    \"\"\"Pick the diagonal of the similarity matrix with the largest sum.\n",
    "\n",
    "    Each diagonal corresponds to one temporal alignment of the two videos.\n",
    "    The ``min(m, n) // 4`` outermost diagonals at either end are ignored.\n",
    "\n",
    "    Args:\n",
    "        similarity (torch.Tensor): 2-D similarity matrix of shape [m, n]\n",
    "            (rows are used as student frames, columns as teacher frames).\n",
    "\n",
    "    Returns:\n",
    "        dict: ``stu_start`` / ``tch_start`` (first row / column of the chosen\n",
    "        diagonal), ``length`` (number of frame pairs on it) and\n",
    "        ``similarity`` (np.ndarray with the max-smoothed similarity along it).\n",
    "    \"\"\"\n",
    "    m, n = similarity.size()\n",
    "    rows = torch.arange(m, device=similarity.device).view(-1, 1)\n",
    "    cols = torch.arange(n, device=similarity.device).view(1, -1)\n",
    "    # diagonal id of every cell: 0 is the bottom-left corner,\n",
    "    # m - 1 the main diagonal, m + n - 2 the top-right corner\n",
    "    diagonal_ids = (m - 1 - rows + cols).reshape(-1)\n",
    "\n",
    "    diagonal_sums = torch.zeros(\n",
    "        m + n - 1, dtype=similarity.dtype, device=similarity.device)\n",
    "    diagonal_sums.scatter_add_(0, diagonal_ids, similarity.reshape(-1))\n",
    "\n",
    "    # Ignore the same number of diagonals at both ends. Writing the tail\n",
    "    # slice as `-min(m, n) // 4` would floor the negated value and drop one\n",
    "    # extra diagonal whenever min(m, n) is not a multiple of 4.\n",
    "    margin = min(m, n) // 4\n",
    "    if margin > 0:\n",
    "        diagonal_sums[:margin] = 0\n",
    "        diagonal_sums[-margin:] = 0\n",
    "    index = diagonal_sums.argmax().item()\n",
    "\n",
    "    # max-filter each row over an 11-column window before reading the diagonal\n",
    "    similarity_smooth = torch.nn.functional.max_pool2d(\n",
    "        similarity[None], (1, 11), stride=(1, 1), padding=(0, 5))[0]\n",
    "    similarity_vec = similarity_smooth.diagonal(offset=index - m +\n",
    "                                                1).cpu().numpy()\n",
    "\n",
    "    stu_start = max(0, m - 1 - index)\n",
    "    tch_start = max(0, index - m + 1)\n",
    "\n",
    "    return dict(\n",
    "        stu_start=stu_start,\n",
    "        tch_start=tch_start,\n",
    "        length=len(similarity_vec),\n",
    "        similarity=similarity_vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c0e19df-949d-471d-804d-409b3b9ddf7d",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T12:58:44.860190Z",
     "iopub.status.busy": "2023-07-05T12:58:44.859878Z",
     "iopub.status.idle": "2023-07-05T12:58:44.888465Z",
     "shell.execute_reply": "2023-07-05T12:58:44.887917Z",
     "shell.execute_reply.started": "2023-07-05T12:58:44.860173Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# choose the best-aligned diagonal of the similarity matrix as the matched snippet\n",
    "matched_piece_info = select_piece_from_similarity(similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51b0a2bd-253c-4a8f-a82a-263e18a4703e",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T13:01:19.061408Z",
     "iopub.status.busy": "2023-07-05T13:01:19.060857Z",
     "iopub.status.idle": "2023-07-05T13:01:19.293742Z",
     "shell.execute_reply": "2023-07-05T13:01:19.293298Z",
     "shell.execute_reply.started": "2023-07-05T13:01:19.061378Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# overlay the selected snippet (red line) on the similarity map\n",
    "last_offset = matched_piece_info['length'] - 1\n",
    "tch_span = (matched_piece_info['tch_start'],\n",
    "            matched_piece_info['tch_start'] + last_offset)\n",
    "stu_span = (matched_piece_info['stu_start'],\n",
    "            matched_piece_info['stu_start'] + last_offset)\n",
    "\n",
    "plt.imshow(similarity.cpu().numpy())\n",
    "plt.plot(tch_span, stu_span, 'r')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ffcde4e7-ff50-483a-b515-604c1d8f121a",
   "metadata": {},
   "source": [
    "# Generate Output Video"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72171a0c-ab33-45bb-b84c-b15f0816ed3a",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T13:11:50.063595Z",
     "iopub.status.busy": "2023-07-05T13:11:50.063259Z",
     "iopub.status.idle": "2023-07-05T13:11:50.070929Z",
     "shell.execute_reply": "2023-07-05T13:11:50.070411Z",
     "shell.execute_reply.started": "2023-07-05T13:11:50.063574Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "def resize_image_to_fixed_height(image: np.ndarray,\n",
    "                                 fixed_height: int) -> np.ndarray:\n",
    "    \"\"\"Scale an image to a given height while keeping its aspect ratio.\n",
    "\n",
    "    Args:\n",
    "        image (np.ndarray): Input image, [H, W, C].\n",
    "        fixed_height (int): Height of the returned image.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Resized image of shape (fixed_height, new_width, C).\n",
    "    \"\"\"\n",
    "    src_height, src_width = image.shape[:2]\n",
    "\n",
    "    # the width follows the height's scale factor, truncated to an integer\n",
    "    new_width = int(src_width * (fixed_height / src_height))\n",
    "\n",
    "    # cv2.resize takes the target size as (width, height)\n",
    "    return cv2.resize(image, (new_width, fixed_height))\n",
    "\n",
    "def blend_images(img1: np.ndarray,\n",
    "                 img2: np.ndarray,\n",
    "                 blend_ratios: Tuple[float, float] = (1, 1)) -> np.ndarray:\n",
    "    \"\"\"Weighted sum of two images, returned as a uint8 image.\n",
    "\n",
    "    uint8 inputs are first rescaled to [0, 1]; inputs of any other dtype are\n",
    "    used as they are. The weighted sum is clipped to [0, 1] before it is\n",
    "    converted back to uint8.\n",
    "\n",
    "    Args:\n",
    "        img1 (np.ndarray): First image, [H, W, C].\n",
    "        img2 (np.ndarray): Second image, [H, W, C].\n",
    "        blend_ratios (tuple): Weights applied to ``img1`` and ``img2``.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: Blended uint8 image, [H, W, C].\n",
    "    \"\"\"\n",
    "\n",
    "    def _to_unit_range(image: np.ndarray) -> np.ndarray:\n",
    "        # only uint8 images are rescaled; other dtypes pass through untouched\n",
    "        if image.dtype != np.uint8:\n",
    "            return image\n",
    "        return image.astype(np.float32) / 255.0\n",
    "\n",
    "    weight1, weight2 = blend_ratios[0], blend_ratios[1]\n",
    "    mixed = _to_unit_range(img1) * weight1 + _to_unit_range(img2) * weight2\n",
    "    mixed = np.clip(mixed, 0, 1)\n",
    "\n",
    "    return (mixed * 255).astype(np.uint8)\n",
    "\n",
    "def get_smoothed_kpt(kpts, index, sigma=5):\n",
    "    \"\"\"Smooth the keypoints of one frame with Gaussian-weighted neighbours.\n",
    "\n",
    "    Args:\n",
    "        kpts (np.ndarray): Keypoints of all frames, [T, 17, 3] with the last\n",
    "            axis holding (x, y, score). It is not modified.\n",
    "        index (int): Frame whose keypoints are smoothed.\n",
    "        sigma (int): Odd number of frames in the smoothing window, centred\n",
    "            on ``index`` and truncated at the sequence boundaries.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray: New [17, 3] array with smoothed (x, y) and the original\n",
    "        score of frame ``index``.\n",
    "    \"\"\"\n",
    "    assert kpts.shape[1] == 17\n",
    "    assert kpts.shape[2] == 3\n",
    "    assert sigma % 2 == 1\n",
    "\n",
    "    num_kpts = len(kpts)\n",
    "\n",
    "    start_idx = max(0, index - sigma // 2)\n",
    "    end_idx = min(num_kpts, index + sigma // 2 + 1)\n",
    "\n",
    "    # Work on copies. `kpts[index]` alone would be a view, so writing the\n",
    "    # smoothed coordinates (or any later in-place shift by the caller) would\n",
    "    # leak into `kpts` and contaminate the windows of the following frames.\n",
    "    piece = kpts[start_idx:end_idx].copy()\n",
    "    smoothed_kpt = kpts[index].copy()\n",
    "\n",
    "    # Split the piece into coordinates and scores\n",
    "    coords, scores = piece[..., :2], piece[..., 2]\n",
    "\n",
    "    # Gaussian weight of each frame by its distance (in frames) to `index`\n",
    "    frame_offsets = np.arange(len(scores)) + start_idx - index\n",
    "    gaussian_ratio = np.exp(-frame_offsets**2 / 2)\n",
    "\n",
    "    # Weight the keypoint scores by the temporal Gaussian\n",
    "    scores *= gaussian_ratio[:, None]\n",
    "\n",
    "    # Score-weighted average of the coordinates over the window\n",
    "    smoothed_coords = (coords * scores[..., None]).sum(axis=0) / (\n",
    "        scores[..., None].sum(axis=0) + 1e-4)\n",
    "\n",
    "    smoothed_kpt[..., :2] = smoothed_coords\n",
    "\n",
    "    return smoothed_kpt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "609b5adc-e176-4bf9-b9a4-506f72440017",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T13:12:46.198835Z",
     "iopub.status.busy": "2023-07-05T13:12:46.198268Z",
     "iopub.status.idle": "2023-07-05T13:12:46.202273Z",
     "shell.execute_reply": "2023-07-05T13:12:46.200881Z",
     "shell.execute_reply.started": "2023-07-05T13:12:46.198815Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# state shared with the rendering loop below\n",
    "score = 0  # running total of the per-frame scores\n",
    "last_vis_score = 0  # value currently shown on screen\n",
    "video_writer = None  # created lazily once the first frame is composed\n",
    "output_file = 'output.mp4'\n",
    "\n",
    "# pose sequences used for drawing the skeletons\n",
    "stu_kpts = student_poses\n",
    "tch_kpts = teacher_poses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a264405a-5d50-49de-8637-2d1f67cb0a70",
   "metadata": {
    "ExecutionIndicator": {
     "show": true
    },
    "execution": {
     "iopub.execute_input": "2023-07-05T13:13:11.334760Z",
     "iopub.status.busy": "2023-07-05T13:13:11.334433Z",
     "iopub.status.idle": "2023-07-05T13:13:17.264181Z",
     "shell.execute_reply": "2023-07-05T13:13:17.262931Z",
     "shell.execute_reply.started": "2023-07-05T13:13:11.334742Z"
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mmengine.structures import InstanceData\n",
    "\n",
    "# Re-open both videos and skip ahead to the first frame of the matched snippet.\n",
    "tch_video_reader = VideoReader(teacher_video)\n",
    "stu_video_reader = VideoReader(student_video)\n",
    "for _ in range(matched_piece_info['tch_start']):\n",
    "    _ = next(tch_video_reader)\n",
    "for _ in range(matched_piece_info['stu_start']):\n",
    "    _ = next(stu_video_reader)\n",
    "    \n",
    "# One output frame per matched frame pair: student | skeletons + score | teacher.\n",
    "for i in track_iter_progress(range(matched_piece_info['length'])):\n",
    "    # frames are converted BGR -> RGB and resized to a common height of 300\n",
    "    tch_frame = mmcv.bgr2rgb(next(tch_video_reader))\n",
    "    stu_frame = mmcv.bgr2rgb(next(stu_video_reader))\n",
    "    tch_frame = resize_image_to_fixed_height(tch_frame, 300)\n",
    "    stu_frame = resize_image_to_fixed_height(stu_frame, 300)\n",
    "\n",
    "    # keypoints of the current frame pair, smoothed over a 5-frame window\n",
    "    stu_kpt = get_smoothed_kpt(stu_kpts, matched_piece_info['stu_start'] + i,\n",
    "                               5)\n",
    "    tch_kpt = get_smoothed_kpt(tch_kpts, matched_piece_info['tch_start'] + i,\n",
    "                               5)\n",
    "\n",
    "    # draw pose\n",
    "    # shift the keypoints to their place on the 300x256 canvas: both poses\n",
    "    # move down by 44 px, the teacher additionally right by 64 px\n",
    "    # (presumably keypoints are in the 192x256 model-input space - TODO confirm)\n",
    "    stu_kpt[..., 1] += (300 - 256)\n",
    "    tch_kpt[..., 0] += (256 - 192)\n",
    "    tch_kpt[..., 1] += (300 - 256)\n",
    "    stu_inst = InstanceData(\n",
    "        keypoints=stu_kpt[None, :, :2],\n",
    "        keypoint_scores=stu_kpt[None, :, 2])\n",
    "    tch_inst = InstanceData(\n",
    "        keypoints=tch_kpt[None, :, :2],\n",
    "        keypoint_scores=tch_kpt[None, :, 2])\n",
    "    \n",
    "    # each skeleton is drawn on its own black canvas, then the two are blended\n",
    "    # with the student at full weight and the teacher at 0.3\n",
    "    stu_out_img = pose_estimator.visualizer._draw_instances_kpts(\n",
    "        np.zeros((300, 256, 3)), stu_inst)\n",
    "    tch_out_img = pose_estimator.visualizer._draw_instances_kpts(\n",
    "        np.zeros((300, 256, 3)), tch_inst)\n",
    "    out_img = blend_images(\n",
    "        stu_out_img, tch_out_img, blend_ratios=(1, 0.3))\n",
    "\n",
    "    # draw score\n",
    "    # the total grows by similarity * 1000 per frame; the displayed number is\n",
    "    # only refreshed once the total exceeds the shown value by more than 1500\n",
    "    score_frame = matched_piece_info['similarity'][i]\n",
    "    score += score_frame * 1000\n",
    "    if score - last_vis_score > 1500:\n",
    "        last_vis_score = score\n",
    "    pose_estimator.visualizer.set_image(out_img)\n",
    "    pose_estimator.visualizer.draw_texts(\n",
    "        'score: ', (60, 30),\n",
    "        font_sizes=15,\n",
    "        colors=(255, 255, 255),\n",
    "        vertical_alignments='bottom')\n",
    "    # the font size of the number scales with the current frame's similarity\n",
    "    # (never below 0.4 of the base size 30)\n",
    "    pose_estimator.visualizer.draw_texts(\n",
    "        f'{int(last_vis_score)}', (115, 30),\n",
    "        font_sizes=30 * max(0.4, score_frame),\n",
    "        colors=(255, 255, 255),\n",
    "        vertical_alignments='bottom')\n",
    "    out_img = pose_estimator.visualizer.get_image()   \n",
    "    \n",
    "    # concatenate\n",
    "    concatenated_image = np.hstack((stu_frame, out_img, tch_frame))\n",
    "    # the writer is created on the first frame, once the output size is known\n",
    "    # (30 fps, 'mp4v' codec)\n",
    "    if video_writer is None:\n",
    "        video_writer = cv2.VideoWriter(output_file,\n",
    "                                       cv2.VideoWriter_fourcc(*'mp4v'),\n",
    "                                       30,\n",
    "                                       (concatenated_image.shape[1],\n",
    "                                        concatenated_image.shape[0]))\n",
    "    # frames are converted back to BGR before being written\n",
    "    video_writer.write(mmcv.rgb2bgr(concatenated_image))\n",
    "\n",
    "  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "745fdd75-6ed4-4cae-9f21-c2cd486ee918",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-05T13:13:18.704492Z",
     "iopub.status.busy": "2023-07-05T13:13:18.704179Z",
     "iopub.status.idle": "2023-07-05T13:13:18.714843Z",
     "shell.execute_reply": "2023-07-05T13:13:18.713866Z",
     "shell.execute_reply.started": "2023-07-05T13:13:18.704472Z"
    }
   },
   "outputs": [],
   "source": [
    "# release the writer so the output file is finalized (it stays None when no frame was written)\n",
    "if video_writer is not None:\n",
    "    video_writer.release()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cb0bc99-ca19-44f1-bc0a-38e14afa980f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
